{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-18T16:42:19.830501Z",
     "start_time": "2017-12-18T16:42:19.825056Z"
    }
   },
   "source": [
    "# TensorFlow to Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reference\n",
    "- https://github.com/myutwo150/keras-inception-resnet-v2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-trained tensorflow models\n",
    "\n",
    "- https://github.com/davidsandberg/facenet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:45:41.868059Z",
     "start_time": "2017-12-27T10:45:41.207755Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import re\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import sys\n",
    "sys.path.append('../code/')\n",
    "from inception_resnet_v1 import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:45:42.700055Z",
     "start_time": "2017-12-27T10:45:42.696473Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "tf_model_dir = '../model/tf/20170512-110547/'\n",
    "npy_weights_dir = '../model/keras/npy_weights/'\n",
    "weights_dir = '../model/keras/weights/'\n",
    "model_dir = '../model/keras/model/'\n",
    "\n",
    "weights_filename = 'facenet_keras_weights.h5'\n",
    "model_filename = 'facenet_keras.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:45:45.459543Z",
     "start_time": "2017-12-27T10:45:45.456524Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "os.makedirs(npy_weights_dir, exist_ok=True)\n",
    "os.makedirs(weights_dir, exist_ok=True)\n",
    "os.makedirs(model_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:45:47.200605Z",
     "start_time": "2017-12-27T10:45:47.168120Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# regex for renaming the tensors to their corresponding Keras counterpart\n",
    "re_repeat = re.compile(r'Repeat_[0-9_]*b')\n",
    "re_block8 = re.compile(r'Block8_[A-Za-z]')\n",
    "\n",
    "def get_filename(key):\n",
    "    filename = str(key)\n",
    "    filename = filename.replace('/', '_')\n",
    "    filename = filename.replace('InceptionResnetV1_', '')\n",
    "\n",
    "    # remove \"Repeat\" scope from filename\n",
    "    filename = re_repeat.sub('B', filename)\n",
    "\n",
    "    if re_block8.match(filename):\n",
    "        # the last block8 has different name with the previous 5 occurrences\n",
    "        filename = filename.replace('Block8', 'Block8_6')\n",
    "\n",
    "    # from TF to Keras naming\n",
    "    filename = filename.replace('_weights', '_kernel')\n",
    "    filename = filename.replace('_biases', '_bias')\n",
    "\n",
    "    return filename + '.npy'\n",
    "\n",
    "\n",
    "def extract_tensors_from_checkpoint_file(filename, output_folder):\n",
    "    reader = tf.train.NewCheckpointReader(filename)\n",
    "\n",
    "    for key in reader.get_variable_to_shape_map():\n",
    "        # not saving the following tensors\n",
    "        if key == 'global_step':\n",
    "            continue\n",
    "        if 'AuxLogit' in key:\n",
    "            continue\n",
    "\n",
    "        # convert tensor name into the corresponding Keras layer weight name and save\n",
    "        path = os.path.join(output_folder, get_filename(key))\n",
    "        arr = reader.get_tensor(key)\n",
    "        np.save(path, arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:45:48.440655Z",
     "start_time": "2017-12-27T10:45:48.253325Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "extract_tensors_from_checkpoint_file(tf_model_dir+'model-20170512-110547.ckpt-250000', npy_weights_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:45:52.208432Z",
     "start_time": "2017-12-27T10:45:49.347963Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = InceptionResNetV1()\n",
    "# model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-27T10:46:23.916513Z",
     "start_time": "2017-12-27T10:45:53.621944Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading numpy weights from ../model/keras/npy_weights/\n",
      "Saving weights...\n",
      "Saving model...\n"
     ]
    }
   ],
   "source": [
    "print('Loading numpy weights from', npy_weights_dir)\n",
    "for layer in model.layers:\n",
    "    if layer.weights:\n",
    "        weights = []\n",
    "        for w in layer.weights:\n",
    "            weight_name = os.path.basename(w.name).replace(':0', '')\n",
    "            weight_file = layer.name + '_' + weight_name + '.npy'\n",
    "            weight_arr = np.load(os.path.join(npy_weights_dir, weight_file))\n",
    "            weights.append(weight_arr)\n",
    "        layer.set_weights(weights)\n",
    "\n",
    "print('Saving weights...')\n",
    "model.save_weights(os.path.join(weights_dir, weights_filename))\n",
    "print('Saving model...')\n",
    "model.save(os.path.join(model_dir, model_filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
